#include <gtest/gtest.h>
#include "c/ddk/graph/handle_types.h"
#include "c/ddk/graph/tensor.h"
#include "c/ddk/graph/attr_value.h"
#include "graph/attr_value.h"
#include "graph/csrc/resource_manager.h"
#include "c/ddk/graph/context.h"
#include "graph/types.h"
#include "graph/shape.h"

#include "framework/graph/utils/op_desc_utils.h"
#include "framework/graph/utils/attr_utils.h"

using namespace ge;
using namespace std;
using namespace hiai;

class UTEST_c_attr_value : public testing::Test {
public:
    void SetUp()
    {
        resMgr = HIAI_IR_ResourceManagerCreate();
    }

    void TearDown()
    {
        HIAI_IR_ResourceManagerDestroy(&resMgr);
    }
public:
    ResMgrHandle resMgr = nullptr;
};

// Verifies that an int64 attribute created through the C API stores the given value.
TEST_F(UTEST_c_attr_value, CreateInt64Attr)
{
    int64_t value = 123;
    AttrHandle attrHandle = HIAI_IR_CreateInt64Attr(resMgr, value);
    // ASSERT (not EXPECT): the lines below dereference the object behind the handle.
    ASSERT_NE(attrHandle, nullptr);

    auto resMgrPtr = reinterpret_cast<ResourceManager*>(resMgr);
    BasePtr basePtr = resMgrPtr->GetSrcPtr(attrHandle);
    std::shared_ptr<AttrValue> attrPtr = std::dynamic_pointer_cast<AttrValue>(basePtr);
    ASSERT_NE(attrPtr.get(), nullptr);

    EXPECT_EQ(value, attrPtr->GetInt());
}

// Verifies that a bool attribute created through the C API stores the given value.
TEST_F(UTEST_c_attr_value, CreateBoolAttr)
{
    bool value = true;
    AttrHandle attrHandle = HIAI_IR_CreateBoolAttr(resMgr, value);
    // ASSERT (not EXPECT): the lines below dereference the object behind the handle.
    ASSERT_NE(attrHandle, nullptr);

    auto resMgrPtr = reinterpret_cast<ResourceManager*>(resMgr);
    BasePtr basePtr = resMgrPtr->GetSrcPtr(attrHandle);
    std::shared_ptr<AttrValue> attrPtr = std::dynamic_pointer_cast<AttrValue>(basePtr);
    ASSERT_NE(attrPtr.get(), nullptr);

    EXPECT_EQ(value, attrPtr->GetBool());
}

// Verifies that a float attribute created through the C API stores the given value.
TEST_F(UTEST_c_attr_value, CreateFloatAttr)
{
    // 'f' suffix: avoid an implicit double -> float narrowing of the literal.
    float value = 1.23f;
    AttrHandle attrHandle = HIAI_IR_CreateFloatAttr(resMgr, value);
    // ASSERT (not EXPECT): the lines below dereference the object behind the handle.
    ASSERT_NE(attrHandle, nullptr);

    auto resMgrPtr = reinterpret_cast<ResourceManager*>(resMgr);
    BasePtr basePtr = resMgrPtr->GetSrcPtr(attrHandle);
    std::shared_ptr<AttrValue> attrPtr = std::dynamic_pointer_cast<AttrValue>(basePtr);
    ASSERT_NE(attrPtr.get(), nullptr);

    // ULP-tolerant float comparison instead of exact equality.
    EXPECT_FLOAT_EQ(value, attrPtr->GetFloat());
}

// Verifies that a string attribute created through the C API stores the given text.
TEST_F(UTEST_c_attr_value, CreateStringAttr)
{
    const char* value = "test_value";
    AttrHandle attrHandle = HIAI_IR_CreateStringAttr(resMgr, value);
    // ASSERT (not EXPECT): the lines below dereference the object behind the handle.
    ASSERT_NE(attrHandle, nullptr);

    auto resMgrPtr = reinterpret_cast<ResourceManager*>(resMgr);
    BasePtr basePtr = resMgrPtr->GetSrcPtr(attrHandle);
    std::shared_ptr<AttrValue> attrPtr = std::dynamic_pointer_cast<AttrValue>(basePtr);
    ASSERT_NE(attrPtr.get(), nullptr);

    // Compare by content; wrapping in std::string rules out a pointer comparison.
    EXPECT_EQ(std::string(value), attrPtr->GetString());
}

// Verifies that a tensor attribute wraps a tensor with the same data, shape,
// data type and format that were passed to HIAI_IR_CreateTensor.
TEST_F(UTEST_c_attr_value, CreateTensorAttr)
{
    // Owned by the vector so it is released even when an ASSERT_* returns early.
    std::vector<uint8_t> data {0x01, 0x02, 0x03, 0x04};
    size_t dataLen = data.size();
    int64_t shape[] = {1, 2};
    size_t shapeSize = sizeof(shape) / sizeof(shape[0]);
    HIAI_DataType typeC = HIAI_DATATYPE_FLOAT32;
    HIAI_Format formatC = HIAI_Format::HIAI_FORMAT_NCHW;

    TensorHandle tensor_handle = HIAI_IR_CreateTensor(resMgr, static_cast<void*>(data.data()), dataLen,
        shape, shapeSize, typeC, formatC);
    ASSERT_NE(tensor_handle, nullptr);

    AttrHandle attrHandle = HIAI_IR_CreateTensorAttr(resMgr, tensor_handle);
    ASSERT_NE(attrHandle, nullptr);

    auto resMgrPtr = reinterpret_cast<ResourceManager*>(resMgr);
    BasePtr basePtr = resMgrPtr->GetSrcPtr(attrHandle);
    std::shared_ptr<AttrValue> attrPtr = std::dynamic_pointer_cast<AttrValue>(basePtr);
    ASSERT_NE(attrPtr.get(), nullptr);
    std::shared_ptr<Tensor> tensorPtr = std::dynamic_pointer_cast<Tensor>(attrPtr->GetTensor());
    ASSERT_NE(tensorPtr.get(), nullptr);

    // Check sizes before indexing so a short buffer fails the test instead of reading out of bounds.
    ASSERT_EQ(tensorPtr->GetData().size(), dataLen);
    for (size_t i = 0; i < dataLen; ++i) {
        EXPECT_EQ(tensorPtr->GetData().data()[i], data[i]);
    }

    std::vector<int64_t> shapeVec = tensorPtr->GetTensorDesc().GetShape().GetDims();
    ASSERT_EQ(shapeVec.size(), shapeSize);
    for (size_t i = 0; i < shapeSize; ++i) {
        EXPECT_EQ(shapeVec[i], shape[i]);
    }

    // Inspect the stored tensor's own descriptor. The attr's GetTensorDesc() is presumably
    // its separate tensor-desc slot, whose defaults could match trivially — TODO confirm.
    EXPECT_EQ(tensorPtr->GetTensorDesc().GetDataType(), ge::DataType::DT_FLOAT);
    EXPECT_EQ(tensorPtr->GetTensorDesc().GetFormat(), static_cast<ge::Format>(formatC));
}

// Verifies that an int64 list attribute preserves the values and their order.
TEST_F(UTEST_c_attr_value, CreateInt64LstAttr)
{
    int64_t value[] = {1, 2, 3};
    uint32_t num = 3;
    std::vector<int64_t> val(value, value + num);

    AttrHandle attrHandle = HIAI_IR_CreateInt64LstAttr(resMgr, val.data(), num);
    // ASSERT (not EXPECT): the lines below dereference the object behind the handle.
    ASSERT_NE(attrHandle, nullptr);

    auto resMgrPtr = reinterpret_cast<ResourceManager*>(resMgr);
    BasePtr basePtr = resMgrPtr->GetSrcPtr(attrHandle);
    std::shared_ptr<AttrValue> attrPtr = std::dynamic_pointer_cast<AttrValue>(basePtr);
    ASSERT_NE(attrPtr.get(), nullptr);

    EXPECT_EQ(val, attrPtr->GetIntList());
}

// Verifies that a bool list attribute preserves the values and their order.
TEST_F(UTEST_c_attr_value, CreateBoolLstAttr)
{
    bool value[] = {true, false, true};
    uint32_t num = 3;
    std::vector<bool> val(value, value + num);

    AttrHandle attrHandle = HIAI_IR_CreateBoolLstAttr(resMgr, value, num);
    // ASSERT (not EXPECT): the lines below dereference the object behind the handle.
    ASSERT_NE(attrHandle, nullptr);

    auto resMgrPtr = reinterpret_cast<ResourceManager*>(resMgr);
    BasePtr basePtr = resMgrPtr->GetSrcPtr(attrHandle);
    std::shared_ptr<AttrValue> attrPtr = std::dynamic_pointer_cast<AttrValue>(basePtr);
    ASSERT_NE(attrPtr.get(), nullptr);

    EXPECT_EQ(val, attrPtr->GetBoolList());
}

// Verifies that a float list attribute preserves the values and their order.
TEST_F(UTEST_c_attr_value, CreateFloatLstAttr)
{
    float value[] = {1.0f, 2.0f, 3.0f};
    uint32_t num = 3;
    std::vector<float> val(value, value + num);

    AttrHandle attrHandle = HIAI_IR_CreateFloatLstAttr(resMgr, value, num);
    // ASSERT (not EXPECT): the lines below dereference the object behind the handle.
    ASSERT_NE(attrHandle, nullptr);

    auto resMgrPtr = reinterpret_cast<ResourceManager*>(resMgr);
    BasePtr basePtr = resMgrPtr->GetSrcPtr(attrHandle);
    std::shared_ptr<AttrValue> attrPtr = std::dynamic_pointer_cast<AttrValue>(basePtr);
    ASSERT_NE(attrPtr.get(), nullptr);

    // Exact comparison is fine here: the values are exactly representable and only copied.
    EXPECT_EQ(val, attrPtr->GetFloatList());
}

// Verifies that a string list attribute preserves the strings and their order.
TEST_F(UTEST_c_attr_value, CreateStringLstAttr)
{
    const char *value[] = {"hello", "world"};
    size_t size = sizeof(value) / sizeof(value[0]);
    std::vector<std::string> val(value, value + size);

    AttrHandle attrHandle = HIAI_IR_CreateStringLstAttr(resMgr, value, size);
    // ASSERT (not EXPECT): the lines below dereference the object behind the handle.
    ASSERT_NE(attrHandle, nullptr);

    auto resMgrPtr = reinterpret_cast<ResourceManager*>(resMgr);
    BasePtr basePtr = resMgrPtr->GetSrcPtr(attrHandle);
    std::shared_ptr<AttrValue> attrPtr = std::dynamic_pointer_cast<AttrValue>(basePtr);
    ASSERT_NE(attrPtr.get(), nullptr);

    EXPECT_EQ(val, attrPtr->GetStringList());
}
